import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import truncnorm
from tqdm import tqdm

# --- Utility Functions ---
def make_truncnorm(mean, std, lower, upper):
    """Return a frozen normal(mean, std) distribution truncated to [lower, upper].

    ``scipy.stats.truncnorm`` takes its clip points in standard-normal units,
    so the absolute bounds are converted to z-scores before freezing.
    """
    a_z, b_z = ((bound - mean) / std for bound in (lower, upper))
    return truncnorm(a_z, b_z, loc=mean, scale=std)

def build_Cmax(years, milestones):
    """Build a piecewise-linear ceiling series evaluated at ``years``.

    Between consecutive milestone years the value is linearly interpolated;
    before the first milestone it is held at the first milestone value and
    after the last one at the last milestone value.

    Args:
        years: array-like of years at which to evaluate the series.
        milestones: non-empty mapping ``{year: value}``; keys need not be
            sorted.

    Returns:
        float ndarray with the same shape as ``years``.

    Raises:
        ValueError: if ``milestones`` is empty.
    """
    if not milestones:
        raise ValueError("milestones must contain at least one (year, value) pair")
    knot_years = sorted(milestones)
    knot_values = [milestones[year] for year in knot_years]
    # np.interp interpolates linearly between knots and clamps to the end
    # values outside them. This also covers a single milestone, which the
    # previous segment loop left at 0 for the milestone year itself.
    return np.interp(np.asarray(years, dtype=float), knot_years, knot_values)

def logistic_forecast(C0, b, Cmax_series):
    """Iterate a discrete logistic growth model against a moving ceiling.

    Each step applies ``C_next = C + b * C * (1 - C / Cmax)``, where ``Cmax``
    is the ceiling for the step being produced. ``Cmax_series[0]`` is not
    used, because the first output entry is ``C0`` itself.

    Args:
        C0: starting value.
        b: growth rate per step (callers in this file pass a fraction).
        Cmax_series: sequence of ceilings, one per output step.

    Returns:
        1-D ndarray ``[C0, C1, ...]`` with ``max(1, len(Cmax_series))`` entries.
    """
    trajectory = [C0]
    level = C0
    for ceiling in Cmax_series[1:]:
        level = level + b * level * (1 - level / ceiling)
        trajectory.append(level)
    return np.array(trajectory)

def run_simulations(b_sampler, Cmax_series, C0_dist, n_sim=1000):
    """Run a Monte Carlo batch of logistic forecasts.

    Each draw samples an initial value from ``C0_dist`` and a growth rate from
    ``b_sampler``, then forecasts with ``logistic_forecast``.

    Args:
        b_sampler: zero-argument callable returning one growth rate per call.
        Cmax_series: ceiling series passed through to ``logistic_forecast``.
        C0_dist: object with an ``rvs()`` method (e.g. a frozen scipy
            distribution) returning one initial value per call.
        n_sim: number of draws.

    Returns:
        2-D ndarray with one forecast trajectory per row. Draws whose forecast
        raised an arithmetic or value error are skipped, so there may be fewer
        than ``n_sim`` rows. If every draw fails, the result has shape
        ``(0, len(Cmax_series))``.
    """
    results = []
    for _ in range(n_sim):
        C0 = C0_dist.rvs()
        b = b_sampler()
        try:
            trajectory = logistic_forecast(C0, b, Cmax_series)
        except (ArithmeticError, ValueError):
            # Skip only numerically failed draws. The previous bare ``except:``
            # also swallowed KeyboardInterrupt and genuine bugs (e.g. TypeError).
            continue
        results.append(trajectory)
    if not results:
        # Keep the result 2-D so callers can index columns consistently.
        return np.empty((0, len(Cmax_series)))
    return np.array(results)

# --- Parameters ---
years = np.arange(2025, 2050 + 1)    # forecast horizon, both ends inclusive
anticipation = 5                     # years by which demand targets are pulled forward
target_2030, target_2050 = 471, 894  # capacity targets, plotted on the [GW] axis
n_sim = 1000                         # Monte Carlo draws per scenario
tolerance = 0.005                    # accepted relative shortfall vs. the 2050 target
threshold = (1 - tolerance) * target_2050

# --- Demand Pull ---
# Ceiling milestones; the 2025 level equals the mean of the C0 distribution below.
demand_targets = {2025: 2.13, 2030: target_2030, 2050: target_2050}
# Move every milestone `anticipation` years earlier before interpolating.
shifted_targets = dict(
    (milestone_year - anticipation, level)
    for milestone_year, level in demand_targets.items()
)
Cmax_series = build_Cmax(years, shifted_targets)

# --- Initial Capacity Distribution ---
C0_dist = make_truncnorm(mean=2.13, std=4.28, lower=0.26, upper=14.33)

# --- Step 1: Find minimum required growth rate based on median ---
# Scan fixed growth rates (in %/year) from low to high and keep the first one
# whose median 2050 value reaches `threshold`. Only C0 is random in this scan.
mean_range = np.arange(20, 71, 1)
min_required_mean = None
# Column holding the 2050 value; loop-invariant, so computed once.
idx_2050 = np.flatnonzero(years == 2050)[0]

for mean_b in tqdm(mean_range):
    # Bind the rate as a default argument so the sampler does not look up the
    # loop variable at call time.
    sim_results = run_simulations(
        lambda fixed_b=mean_b / 100: fixed_b, Cmax_series, C0_dist, n_sim=n_sim
    )
    if sim_results.shape[0] == 0:
        continue  # every draw failed for this rate
    p50_2050 = np.percentile(sim_results[:, idx_2050], 50)
    if p50_2050 >= threshold:
        min_required_mean = mean_b
        break

if min_required_mean is None:
    # Fail here with a clear message. Otherwise the script crashes later with
    # an unrelated TypeError (``None / 100`` in Step 2, ``{None:.1f}`` in the plot).
    raise RuntimeError(
        f"No fixed growth rate in {mean_range[0]}-{mean_range[-1]} %/year "
        f"brings the median 2050 value to {threshold:.1f}; extend mean_range."
    )

# --- Step 2: Simulate base and sensitivity scenarios ---
# Base scenario: growth rate drawn per run from a truncated normal over
# 15-70 %/year (mean 39), converted to a fraction.
b_dist_base = make_truncnorm(mean=39, std=11.86, lower=15, upper=70)


def _draw_base_growth():
    """Return one random base-scenario growth rate as a fraction."""
    return b_dist_base.rvs() / 100


def _sensitivity_growth():
    """Return the fixed Step-1 growth rate as a fraction."""
    return min_required_mean / 100


base_results = run_simulations(_draw_base_growth, Cmax_series, C0_dist, n_sim=n_sim)
# Sensitivity scenario: every run uses the same fixed minimum rate from Step 1.
sens_results = run_simulations(_sensitivity_growth, Cmax_series, C0_dist, n_sim=n_sim)

# --- Step 3: Compute Percentiles ---
# Per-year quartiles across base runs (each row of base_results is one run).
base_quartiles = np.percentile(base_results, [25, 50, 75], axis=0)
p25_base, p50_base, p75_base = base_quartiles
# Per-year median across the fixed-rate sensitivity runs.
p50_sens = np.percentile(sens_results, 50, axis=0)

# --- Step 4: Plot ---
_FONT_SIZE = 15  # shared by axis labels, tick labels and legend

plt.figure(figsize=(10, 6))

# Base scenario: 25th-75th percentile band plus the median line.
plt.fill_between(
    years,
    p25_base,
    p75_base,
    color='tab:orange',
    alpha=0.5,
    label='50% interval (Base)',
)
plt.plot(years, p50_base, color='tab:orange', label='Median (Base)')

# Sensitivity scenario at the fixed minimum growth rate found in Step 1.
sens_label = f'Median – Sensitivity ({min_required_mean:.1f}%)'
plt.plot(years, p50_sens, color='black', linestyle='--', linewidth=2, label=sens_label)

# Demand-pull ceiling plus reference lines at the 2050 target.
plt.plot(years, Cmax_series, '--', color='gray', label='Median demand pull')
plt.axhline(target_2050, color='gray', linestyle=':')
plt.axvline(2050, color='black', linestyle=':')

plt.xlabel("Year", fontsize=_FONT_SIZE)
plt.ylabel("PEM Electrolyzer Capacity [GW]", fontsize=_FONT_SIZE)
plt.title("PEM Capacity Forecast: Base vs Minimum Required Growth Rate (Median)")
plt.grid(True, linestyle='--', alpha=0.6)
plt.legend(fontsize=_FONT_SIZE)
for _set_ticks in (plt.xticks, plt.yticks):
    _set_ticks(fontsize=_FONT_SIZE)
plt.tight_layout()
plt.show()

# --- Output ---
# Report the Step-1 result (already in percent per year).
summary_line = f" Minimum fixed growth rate required (Median): {min_required_mean:.1f}%/year"
print(summary_line)
